{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transfer Learning / Pretraining\n",
    "Transfer learning (or pretraining) leverages knowledge from a pre-trained model on a related task to enhance performance on a new task. In Chemprop, we can use pre-trained model checkpoints to initialize a new model and freeze components of the new model during training, as demonstrated in this notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/transfer_learning.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        !pip install .\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "from lightning import pytorch as pl\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import torch\n",
    "\n",
    "from chemprop import data, featurizers, models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Change data inputs here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "input_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\" # path to your data .csv file\n",
    "num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading\n",
    "smiles_column = 'smiles' # name of the column containing SMILES strings\n",
    "target_columns = ['lipo'] # list of names of the columns containing targets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>lipo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14</td>\n",
       "      <td>3.54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...</td>\n",
       "      <td>-1.18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl</td>\n",
       "      <td>3.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...</td>\n",
       "      <td>3.37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...</td>\n",
       "      <td>3.10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...</td>\n",
       "      <td>2.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...</td>\n",
       "      <td>2.04</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...</td>\n",
       "      <td>4.49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>COc1ccc(Cc2c(N)n[nH]c2N)cc1</td>\n",
       "      <td>0.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...</td>\n",
       "      <td>2.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               smiles  lipo\n",
       "0             Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14  3.54\n",
       "1   COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18\n",
       "2              COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl  3.69\n",
       "3   OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...  3.37\n",
       "4   Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...  3.10\n",
       "..                                                ...   ...\n",
       "95  CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...  2.20\n",
       "96  CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...  2.04\n",
       "97  CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...  4.49\n",
       "98                        COc1ccc(Cc2c(N)n[nH]c2N)cc1  0.20\n",
       "99  CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...  2.00\n",
       "\n",
       "[100 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_input = pd.read_csv(input_path)\n",
    "df_input"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get SMILES and targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "smis = df_input.loc[:, smiles_column].values\n",
    "ys = df_input.loc[:, target_columns].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',\n",
       "       'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',\n",
       "       'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',\n",
       "       'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',\n",
       "       'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "smis[:5] # show first 5 SMILES strings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 3.54],\n",
       "       [-1.18],\n",
       "       [ 3.69],\n",
       "       [ 3.37],\n",
       "       [ 3.1 ]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ys[:5] # show first 5 targets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get molecule datapoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform data splitting for training, validation, and testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['SCAFFOLD_BALANCED',\n",
       " 'RANDOM_WITH_REPEATED_SMILES',\n",
       " 'RANDOM',\n",
       " 'KENNARD_STONE',\n",
       " 'KMEANS']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# available split types\n",
    "list(data.SplitType.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "mols = [d.mol for d in all_data]  # RDkit Mol objects are use for structure based splits\n",
    "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n",
    "train_data, val_data, test_data = data.split_data_by_indices(\n",
    "    all_data, train_indices, val_indices, test_indices\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Change checkpoint model inputs here\n",
    "Both message-passing neural networks (MPNNs) and multi-component MPNNs can have their weights initialized from a checkpoint file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "checkpoint_path = chemprop_dir / \"tests\" / \"data\" / \"example_model_v2_regression_mol.ckpt\" # path to the checkpoint file.\n",
    "# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder with name similar to `checkpoints/epoch=19-step=180.ckpt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "mpnn_cls = models.MPNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPNN(\n",
       "  (message_passing): BondMessagePassing(\n",
       "    (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
       "    (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "    (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
       "    (dropout): Dropout(p=0.0, inplace=False)\n",
       "    (tau): ReLU()\n",
       "    (V_d_transform): Identity()\n",
       "    (graph_transform): GraphTransform(\n",
       "      (V_transform): Identity()\n",
       "      (E_transform): Identity()\n",
       "    )\n",
       "  )\n",
       "  (agg): MeanAggregation()\n",
       "  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): UnscaleTransform()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0-1): 2 x MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mpnn = mpnn_cls.load_from_file(checkpoint_path)\n",
    "mpnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scale fine-tuning data with the model's target scaler\n",
    "\n",
    "If the pre-trained model was a regression model, it probably was trained on a scaled dataset. The scaler is saved as part of the model and used during prediction. For furthur training, we need to scale the fine-tuning data with the same target scaler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "pretraining_scaler = StandardScaler()\n",
    "pretraining_scaler.mean_ = mpnn.predictor.output_transform.mean.numpy()\n",
    "pretraining_scaler.scale_ = mpnn.predictor.output_transform.scale.numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get MoleculeDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()\n",
    "\n",
    "train_dset = data.MoleculeDataset(train_data[0], featurizer)\n",
    "train_dset.normalize_targets(pretraining_scaler)\n",
    "\n",
    "val_dset = data.MoleculeDataset(val_data[0], featurizer)\n",
    "val_dset.normalize_targets(pretraining_scaler)\n",
    "\n",
    "test_dset = data.MoleculeDataset(test_data[0], featurizer)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = data.build_dataloader(train_dset, num_workers=num_workers)\n",
    "val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)\n",
    "test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Freezing MPNN and FFN layers\n",
    "Certain layers of a pre-trained model can be kept unchanged during further training on a new task."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Freezing the MPNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mpnn.message_passing.apply(lambda module: module.requires_grad_(False))\n",
    "mpnn.message_passing.eval()\n",
    "mpnn.bn.apply(lambda module: module.requires_grad_(False))\n",
    "mpnn.bn.eval()  # Set batch norm layers to eval mode to freeze running mean and running var."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Freezing FFN layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "frzn_ffn_layers = 1  # the number of consecutive FFN layers to freeze."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx in range(frzn_ffn_layers):\n",
    "    mpnn.predictor.ffn[idx].requires_grad_(False)\n",
    "    mpnn.predictor.ffn[idx + 1].eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set up trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(\n",
    "    logger=False,\n",
    "    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.\n",
    "    enable_progress_bar=True,\n",
    "    accelerator=\"auto\",\n",
    "    devices=1,\n",
    "    max_epochs=20, # number of epochs to train for\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/brianli/Documents/chemprop/examples/checkpoints exists and is not empty.\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
      "\n",
      "  | Name            | Type               | Params | Mode \n",
      "---------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing | 227 K  | eval \n",
      "1 | agg             | MeanAggregation    | 0      | train\n",
      "2 | bn              | BatchNorm1d        | 600    | eval \n",
      "3 | predictor       | RegressionFFN      | 90.6 K | train\n",
      "4 | X_d_transform   | Identity           | 0      | train\n",
      "5 | metrics         | ModuleList         | 0      | train\n",
      "---------------------------------------------------------------\n",
      "301       Trainable params\n",
      "318 K     Non-trainable params\n",
      "318 K     Total params\n",
      "1.276     Total estimated model params size (MB)\n",
      "11        Modules in train mode\n",
      "15        Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0:   0%|                       | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
      "/Users/brianli/Documents/chemprop/chemprop/nn/message_passing/base.py:263: UserWarning: The operator 'aten::scatter_reduce.two_out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:13.)\n",
      "  M_all = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|█████████████| 2/2 [00:00<00:00,  5.45it/s, train_loss_step=0.871]\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 48.68it/s]\u001b[A\n",
      "Epoch 1: 100%|█| 2/2 [00:00<00:00, 18.36it/s, train_loss_step=1.030, val_loss=0.\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 19.34it/s]\u001b[A\n",
      "Epoch 2: 100%|█| 2/2 [00:00<00:00, 21.85it/s, train_loss_step=1.070, val_loss=0.\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 45.83it/s]\u001b[A\n",
      "Epoch 3: 100%|█| 2/2 [00:00<00:00, 16.39it/s, train_loss_step=1.310, val_loss=0.\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 34.51it/s]\u001b[A\n",
      "Epoch 4: 100%|█| 2/2 [00:00<00:00, 17.72it/s, train_loss_step=0.817, val_loss=0.\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 46.04it/s]\u001b[A\n",
      "Epoch 5: 100%|█| 2/2 [00:00<00:00, 16.90it/s, train_loss_step=0.896, val_loss=0.\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 45.22it/s]\u001b[A\n",
      "Epoch 6: 100%|█| 2/2 [00:00<00:00, 16.48it/s, train_loss_step=1.850, val_loss=0.\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 42.72it/s]\u001b[A\n",
      "Epoch 7: 100%|█| 2/2 [00:00<00:00, 17.50it/s, train_loss_step=1.130, val_loss=0.\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 36.61it/s]\u001b[A\n",
      "Epoch 8: 100%|█| 2/2 [00:00<00:00, 16.94it/s, train_loss_step=1.500, val_loss=0.\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 42.61it/s]\u001b[A\n",
      "Epoch 9: 100%|█| 2/2 [00:00<00:00, 17.21it/s, train_loss_step=1.330, val_loss=0.\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 42.65it/s]\u001b[A\n",
      "Epoch 10: 100%|█| 2/2 [00:00<00:00, 16.14it/s, train_loss_step=1.570, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 20.81it/s]\u001b[A\n",
      "Epoch 11: 100%|█| 2/2 [00:00<00:00, 19.93it/s, train_loss_step=1.190, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 35.92it/s]\u001b[A\n",
      "Epoch 12: 100%|█| 2/2 [00:00<00:00, 14.66it/s, train_loss_step=0.464, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 39.19it/s]\u001b[A\n",
      "Epoch 13: 100%|█| 2/2 [00:00<00:00, 15.37it/s, train_loss_step=0.389, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 42.41it/s]\u001b[A\n",
      "Epoch 14: 100%|█| 2/2 [00:00<00:00, 16.10it/s, train_loss_step=0.776, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 38.58it/s]\u001b[A\n",
      "Epoch 15: 100%|█| 2/2 [00:00<00:00, 14.94it/s, train_loss_step=0.381, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 34.09it/s]\u001b[A\n",
      "Epoch 16: 100%|█| 2/2 [00:00<00:00, 15.54it/s, train_loss_step=0.876, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 39.06it/s]\u001b[A\n",
      "Epoch 17: 100%|█| 2/2 [00:00<00:00, 15.96it/s, train_loss_step=1.320, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 40.17it/s]\u001b[A\n",
      "Epoch 18: 100%|█| 2/2 [00:00<00:00, 16.13it/s, train_loss_step=1.110, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 39.27it/s]\u001b[A\n",
      "Epoch 19: 100%|█| 2/2 [00:00<00:00, 16.39it/s, train_loss_step=1.060, val_loss=0\u001b[A\n",
      "Validation: |                                             | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                         | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 17.26it/s]\u001b[A\n",
      "Epoch 19: 100%|█| 2/2 [00:00<00:00, 10.26it/s, train_loss_step=1.060, val_loss=0\u001b[A"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|█| 2/2 [00:00<00:00,  9.91it/s, train_loss_step=1.060, val_loss=0\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(mpnn, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0: 100%|███████████████████████| 1/1 [00:00<00:00, 21.60it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/mse          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9625480771064758     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/mse         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9625480771064758    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results = trainer.test(mpnn, test_loader)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transfer learning with multicomponent models\n",
    "Multi-component MPNN models have an individual MPNN block for each molecule they parse in one input. These MPNN blocks can be frozen independently of one another for transfer learning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Change data inputs here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "checkpoint_path = chemprop_dir / \"tests\" / \"data\" / \"example_model_v2_regression_mol+mol.ckpt\"  # path to the checkpoint file. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Change checkpoint model inputs here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MulticomponentMPNN(\n",
       "  (message_passing): MulticomponentMessagePassing(\n",
       "    (blocks): ModuleList(\n",
       "      (0-1): 2 x BondMessagePassing(\n",
       "        (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
       "        (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "        (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "        (tau): ReLU()\n",
       "        (V_d_transform): Identity()\n",
       "        (graph_transform): GraphTransform(\n",
       "          (V_transform): Identity()\n",
       "          (E_transform): Identity()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (agg): MeanAggregation()\n",
       "  (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=600, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): UnscaleTransform()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0-1): 2 x MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mpnn_cls = models.MulticomponentMPNN\n",
    "mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)\n",
    "mcmpnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "blocks_to_freeze = [0, 1]  # a list of indices of the individual MPNN blocks to freeze before training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)\n",
    "for i in blocks_to_freeze:\n",
    "    mp_block = mcmpnn.message_passing.blocks[i]\n",
    "    mp_block.apply(lambda module: module.requires_grad_(False))\n",
    "    mp_block.eval()\n",
    "mcmpnn.bn.apply(lambda module: module.requires_grad_(False))\n",
    "mcmpnn.bn.eval()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
